import math
from math import sqrt
import argparse
from pathlib import Path
from unittest import TestCase

# torch

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

# vision imports

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle classes and utils

from dalle_pytorch import distributed_utils
# from dalle_pytorch import DiscreteVAE
from dalle_pytorch.dalle_pytorch_ori import DiscreteVAE, VQProgram
# from dalle_pytorch.dalle_pytorch_oriema import DiscreteVAE
# from dalle_pytorch.dalle_pytorch_ae import DiscreteVAE

# argument parsing

import sys
sys.path.insert(0, '/home/tiangel/DALLE_3D/Learning-to-Group')
from IPython import embed
import glob
# from pytorch3d.io import load_ply
from pytorch3d.io import load_ply, save_ply
from torch.utils.data import Dataset
import os
from partnet.utils.torch_pc import normalize_points as normalize_points_torch

from pytorch3d.io import IO
from pytorch3d.structures import Pointclouds
# from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras, 
    PointsRasterizationSettings,
    PointsRenderer,
    PulsarPointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    NormWeightedCompositor
)
import matplotlib.pyplot as plt
import numpy as np
from geometry_utils import render_pts, rotate_pts, render_pts_with_label

# sys.path.insert(0,'/home/tiangel/DALLE_newest/UnsupervisedPointCloudReconstruction')
# from UnsupervisedPointCloudReconstruction.model import FoldNet_Decoder7
from dalle_pytorch.transformer import Transformer
import torch.nn as nn

sys.path.insert(0, '/home/tiangel/DALLE_3D/shape2prog')
from shape2prog.dataset import Synthesis3D
import torch.nn.functional as F

def _str2bool(value):
    """Parse a command-line boolean such as ``true`` / ``False`` / ``0``.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, as it was with the old ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()

# NOTE(review): --image_folder is required but only feeds IMAGE_PATH, which is
# referenced solely by commented-out dataset code in this script.
parser.add_argument('--image_folder', type = str, required = True,
                    help='path to your folder of images for learning the discrete VAE and its codebook')

parser.add_argument('--image_size', type = int, required = False, default = 128,
                    help='image size')

# Adds the distributed-backend options defined by dalle_pytorch.
parser = distributed_utils.wrap_arg_parser(parser)


train_group = parser.add_argument_group('Training settings')

# Required: the script joins this value onto './outputs/vae_models'
# unconditionally, so a missing value used to fail later with a TypeError.
train_group.add_argument('--vae_path', type=str, required = True,
                   help='path to your trained discrete VAE')

train_group.add_argument('--epochs', type = int, default = 20, help = 'number of epochs')

train_group.add_argument('--batch_size', type = int, default = 8, help = 'batch size')

train_group.add_argument('--learning_rate', type = float, default = 1e-3, help = 'learning rate')

train_group.add_argument('--lr_decay_rate', type = float, default = 0.98, help = 'learning rate decay')

train_group.add_argument('--starting_temp', type = float, default = 1., help = 'starting temperature')

train_group.add_argument('--temp_min', type = float, default = 0.5, help = 'minimum temperature to anneal to')

train_group.add_argument('--anneal_rate', type = float, default = 1e-6, help = 'temperature annealing rate')

train_group.add_argument('--num_images_save', type = int, default = 2, help = 'number of images to save')

model_group = parser.add_argument_group('Model settings')

model_group.add_argument('--num_tokens', type = int, default = 8192, help = 'number of image tokens')

model_group.add_argument('--num_layers', type = int, default = 3, help = 'number of layers (should be 3 or above)')

model_group.add_argument('--num_resnet_blocks', type = int, default = 2, help = 'number of residual net blocks')

model_group.add_argument('--smooth_l1_loss', dest = 'smooth_l1_loss', action = 'store_true')

model_group.add_argument('--emb_dim', type = int, default = 512, help = 'embedding dimension')

model_group.add_argument('--hidden_dim', type = int, default = 256, help = 'hidden dimension')

# NOTE(review): the next four options (and --testae below) are not read
# anywhere in this script and their help texts look copy-pasted; their real
# meaning cannot be determined from here, so the texts are left untouched.
model_group.add_argument('--dim1', type = int, default = 16, help = 'hidden dimension')

model_group.add_argument('--dim2', type = int, default = 32, help = 'hidden dimension')

model_group.add_argument('--final_points', type = int, default = 16, help = 'hidden dimension')

model_group.add_argument('--radius', type = float, default = 0.3, help = 'hidden dimension')

model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight')

model_group.add_argument('--save_name', type = str, default = '1',
                         help = 'suffix for the output directory and checkpoint file names')

model_group.add_argument('--aug', type = _str2bool, default = True,
                         help = 'apply random scale augmentation in PC_Dataset (true/false)')

model_group.add_argument('--testae', type = _str2bool, default = False, help = 'KL loss weight')

args = parser.parse_args()

# constants

# Input description (taken verbatim from the parsed command line).
IMAGE_SIZE, IMAGE_PATH = args.image_size, args.image_folder

# Optimisation schedule.
EPOCHS, BATCH_SIZE = args.epochs, args.batch_size
LEARNING_RATE, LR_DECAY_RATE = args.learning_rate, args.lr_decay_rate

# Model hyper-parameters.
NUM_TOKENS, NUM_LAYERS = args.num_tokens, args.num_layers
NUM_RESNET_BLOCKS, SMOOTH_L1_LOSS = args.num_resnet_blocks, args.smooth_l1_loss
EMB_DIM, HIDDEN_DIM = args.emb_dim, args.hidden_dim
KL_LOSS_WEIGHT = args.kl_loss_weight

# Temperature annealing settings.
STARTING_TEMP, TEMP_MIN, ANNEAL_RATE = (
    args.starting_temp,
    args.temp_min,
    args.anneal_rate,
)

# Number of sample images to save.
NUM_IMAGES_SAVE = args.num_images_save

# initialize distributed backend

# Select the distributed backend requested on the command line and start it.
distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

# Remember whether DeepSpeed is the active backend; later sections branch on it.
using_deepspeed = distributed_utils.using_backend(
    distributed_utils.DeepSpeedBackend
)

# data

class PC_Dataset(Dataset):
    """Dataset over the ``*.ply`` files found directly inside one directory.

    Each item is a ``(points, second)`` pair: ``points`` is the first element
    returned by ``load_ply`` after ``normalize_points_torch`` (and an optional
    random rescale), ``second`` is the second element returned by ``load_ply``
    unchanged (the faces, per the pytorch3d documentation).
    """

    def __init__(self, path):
        """Index every ``.ply`` file in ``path``.

        Args:
            path: directory that is searched (non-recursively) for ``*.ply``.
        """
        self.data_dir = path
        # glob returns files in filesystem order; sort so that an index maps to
        # the same file on every run and on every worker.
        self.data_list = sorted(glob.glob(os.path.join(self.data_dir, '*.ply')))
        self.len = len(self.data_list)
        # Augmentation switch comes from the module-level command-line args.
        self.do_aug = args.aug

    def __getitem__(self, index):
        """Load, normalise and optionally augment the point cloud at ``index``."""
        verts, faces = load_ply(self.data_list[index])
        # Add a batch axis for the normaliser and drop exactly that axis again.
        # A bare squeeze() would also drop the point axis of a one-point cloud.
        # Assumes normalize_points_torch keeps the leading batch axis -- TODO confirm.
        points = normalize_points_torch(verts.unsqueeze(0)).squeeze(0)
        if self.do_aug:
            # One random scale factor in [0.9, 1.05) shared by all coordinates.
            scale = points.new(1).uniform_(0.9, 1.05)
            points[:, 0:3] *= scale
        return (points, faces)

    def __len__(self):
        """Return the number of indexed ``.ply`` files."""
        return self.len

# Training data. Alternative dataset choices are kept below for reference.
# ds = PC_Dataset(IMAGE_PATH)
# NOTE(review): the dataset path is hard-coded; --image_folder is not used here.
ds = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5',10)
# ds = ImageFolder(
#     IMAGE_PATH,
#     T.Compose([
#         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
#         T.Resize(IMAGE_SIZE),
#         T.CenterCrop(IMAGE_SIZE),
#         T.ToTensor()
#     ])
# )

if distributed_utils.using_backend(distributed_utils.HorovodBackend):
    # Shard the dataset across Horovod workers. shuffle=False keeps the
    # deterministic ordering the loader below already asks for, so no
    # per-epoch set_epoch() call is needed.
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds, num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank(), shuffle=False)
else:
    data_sampler = None

# The sampler used to be created but never handed to the loader, so every
# Horovod worker iterated the complete dataset. With data_sampler=None the
# loader behaves exactly as before.
# NOTE(review): under Horovod this reduces the number of batches per worker,
# while the scheduler's T_max is still derived from len(ds).
dl = DataLoader(ds, BATCH_SIZE, shuffle = False, sampler = data_sampler, drop_last=True)

# Restore the pre-trained discrete VAE whose encodings feed the VQProgram head.
# --vae_path is interpreted relative to ./outputs/vae_models.
# map_location='cpu' deserialises onto the host first, so the checkpoint loads
# no matter which GPU index it was saved from; the model is moved to the
# current CUDA device afterwards.
loaded_obj = torch.load(os.path.join('./outputs/vae_models',args.vae_path),
                        map_location='cpu')
vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

vae = DiscreteVAE(
    **vae_params,
)

vae.load_state_dict(weights)
# The VAE is only used for inference in this script.
vae.eval().cuda()

assert len(ds) > 0, 'folder does not contain any images'
if distr_backend.is_root_worker():
    print(f'{len(ds)} images found for training')
# optimizer

# The model actually being trained: placed on the GPU and put in train mode.
vqprogram = (
    VQProgram()
    .cuda()
    .train()
)

# Cosine annealing spans the whole run; the scheduler is stepped once per
# batch, so T_max is the total number of full batches over all epochs.
steps_per_epoch = len(ds) // BATCH_SIZE
total_steps = EPOCHS * steps_per_epoch

opt = Adam(vqprogram.parameters(), lr=LEARNING_RATE)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
# sched = ExponentialLR(optimizer = opt, gamma = LR_DECAY_RATE)


if distr_backend.is_root_worker():
    # Weights & Biases experiment tracking; started on the root worker only,
    # so `run` and `model_config` exist only there.
    import wandb

    model_config = {
        'num_tokens': NUM_TOKENS,
        'smooth_l1_loss': SMOOTH_L1_LOSS,
        'num_resnet_blocks': NUM_RESNET_BLOCKS,
        'kl_loss_weight': KL_LOSS_WEIGHT,
    }

    run = wandb.init(project='dalle_train_vae',
                     job_type='train_model',
                     config=model_config)

# distribute

# Let the backend validate the batch size before anything is wrapped.
distr_backend.check_batch_size(BATCH_SIZE)
deepspeed_config = dict(train_batch_size=BATCH_SIZE)

# NOTE(review): the frozen VAE is passed as the model together with the
# optimizer built over vqprogram's parameters, and the training loop steps
# `opt` / `sched` directly instead of `distr_opt` / `distr_sched`. Confirm this
# is intended for the non-dummy backends.
distributed_objects = distr_backend.distribute(
    args=args,
    model=vae,
    optimizer=opt,
    model_parameters=vae.parameters(),
    # DeepSpeed takes the raw dataset; other backends take the loader.
    training_data=ds if using_deepspeed else dl,
    lr_scheduler=None if using_deepspeed else sched,
    config_params=deepspeed_config,
)
distr_vae, distr_opt, distr_dl, distr_sched = distributed_objects

# Prefer scheduler in `deepspeed_config`: if DeepSpeed handed one back, it
# handles its own scheduling; otherwise fall back to the local scheduler.
using_deepspeed_sched = bool(using_deepspeed) and distr_sched is not None
if distr_sched is None:
    distr_sched = sched


def save_model(path):
    """Write the current ``vqprogram`` weights to ``path``.

    The checkpoint is a dict with the single key ``'weights'`` holding
    ``vqprogram.state_dict()``. Only the root worker writes, so that several
    workers of a distributed run do not write the same file concurrently.

    Args:
        path: destination file passed to ``torch.save``.
    """
    if not distr_backend.is_root_worker():
        return

    save_obj = {
        'weights': vqprogram.state_dict()
    }

    torch.save(save_obj, path)

# starting temperature

def render_pytorch3d(renderer, pts, count, name):
    """Render one point cloud and save it as a PNG inside ``save_dir``.

    Args:
        renderer: callable invoked as ``renderer(point_cloud, gamma=(1e-4,))``;
            presumably a pytorch3d points renderer -- confirm at the call site.
        pts: tensor of points; indexed as ``[:, 1]``, so it must be 2-D.
        count: integer used as the zero-padded file-name prefix.
        name: string appended to the file name.

    The image is written to ``<save_dir>/<count as 4 digits>_<name>.png`` at
    300 dpi, where ``save_dir`` is the module-level output directory.
    """
    # Constant per-point colour (0, 0.5, 0). Allocate it on the device pts
    # lives on, rather than forcing the default CUDA device, so points and
    # features always share a device and no host-to-device copy is needed.
    rgb = torch.zeros(pts.shape, device=pts.device)
    rgb[:, 1] = 0.5
    point_cloud = Pointclouds(points=[pts], features=[rgb])

    rendered_img = renderer(point_cloud, gamma=(1e-4,))
    # Values that are exactly zero are overwritten with 1 (presumably to turn
    # the empty background white -- confirm).
    rendered_img[rendered_img == 0] = 1
    plt.figure(figsize=(10, 10))
    # First image of the batch, first three channels only.
    plt.imshow(rendered_img[0, ..., :3].detach().cpu().numpy())
    plt.axis("off")
    plt.savefig(os.path.join(save_dir, '%04d'%count+'_'+name+'.png'),dpi=300)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()

# Run-level state.
global_step = 0
temp = STARTING_TEMP

# Output directory for this run: ./outputs/vae_outputs/test<save_name>.
save_dir = os.path.join('./outputs/vae_outputs','test'+args.save_name)
# makedirs(exist_ok=True) also creates missing parents and cannot fail when
# another worker creates the directory between a check and the creation.
os.makedirs(save_dir, exist_ok=True)

count = 0
# Metric accumulators (not appended to by the training loop in this script).
cd_loss_list = []
emd_loss_list = []
vq_loss_list = []
perplexity_list = []

def to_contiguous(tensor):
    """Return ``tensor`` itself when already contiguous, else a contiguous copy."""
    return tensor if tensor.is_contiguous() else tensor.contiguous()

# Layout of the flat VQProgram output: the first NUM_PGM_SLOTS * NUM_PGM_CLASSES
# values are program logits, everything after that is regressed parameters.
# NOTE(review): 30 presumably equals n_block * n_step and 22 the number of
# program classes -- confirm against VQProgram and the dataset.
NUM_PGM_SLOTS = 30
NUM_PGM_CLASSES = 22
PGM_WIDTH = NUM_PGM_SLOTS * NUM_PGM_CLASSES

# Train the VQProgram head on encodings of the frozen VAE.
# NOTE(review): the loop steps `opt` / `sched` directly rather than the
# `distr_opt` / `distr_sched` objects returned by the backend.
for epoch in range(EPOCHS):
    for i, data in enumerate(distr_dl):
        shapes, labels, masks, params, param_masks = data[0], data[1], data[2], data[3], data[4]
        # Named `points` rather than `input` to avoid shadowing the builtin.
        points = normalize_points_torch(shapes).cuda()
        # The VAE is frozen: no gradients are needed for the encoding.
        with torch.no_grad():
            vq_encoding = distr_vae.get_encoding(points).squeeze(-1)
        # vq_encoding: [8, 512, 128] (original author's note)
        out = vqprogram(vq_encoding.permute(0,2,1).contiguous())

        bsz, n_block, n_step = labels.size()
        out_pgm = out[:, :PGM_WIDTH].reshape(bsz, NUM_PGM_SLOTS, NUM_PGM_CLASSES)
        out_pgm = F.log_softmax(out_pgm, dim=-1)

        # compute program classification loss (masked negative log-likelihood)
        pred = out_pgm.contiguous().view(-1, out_pgm.size(2))
        target = labels.contiguous().view(-1, 1).cuda()
        mask = masks.contiguous().view(-1, 1).cuda()
        # Guard the masked means: a batch whose mask is entirely zero would
        # otherwise divide by zero and push NaN into the weights.
        # Assumes masks hold 0/1 entries -- TODO confirm.
        mask_total = torch.sum(mask).clamp(min=1)
        loss_cls = torch.sum(- pred.gather(1, target) * mask) / mask_total
        _, idx = torch.max(pred, dim=1)
        # squeeze(1) removes only the column axis, so a single-row batch
        # is not collapsed to a 0-d tensor.
        correct = idx.eq(target.squeeze(1)).float() * mask.squeeze(1)
        acc = torch.sum(correct) / mask_total

        # compute parameter regression loss (masked half squared error)
        bsz, n_block, n_step, n_param = params.size()
        out_param = out[:, PGM_WIDTH:].reshape(bsz, n_block * n_step, n_param)
        params = params.contiguous().view(bsz, n_block * n_step, n_param).cuda()
        param_masks = param_masks.contiguous().view(bsz, n_block * n_step, n_param).cuda()
        diff = 0.5 * (out_param - params) ** 2
        diff = diff * param_masks
        loss_reg = torch.sum(diff) / torch.sum(param_masks).clamp(min=1)

        # The regression term is weighted three times the classification term.
        loss = loss_cls + 3*loss_reg
        # Report progress from the root worker only, like the other console
        # output of this script.
        if i % 50 == 0 and distr_backend.is_root_worker():
            print('loss_cls:%.3f, acc:%.3f, loss_reg:%.3f'%(loss_cls, acc, loss_reg))

        opt.zero_grad()
        loss.backward()
        opt.step()
        # The cosine schedule is advanced once per batch.
        sched.step()

        # Periodic checkpoint (also fires on the first batch of every epoch).
        if i % 1000 == 0:
            save_model('./outputs/vae_models/vqprogram'+args.save_name+'.pt')

save_model('./outputs/vae_models/vqprogram-final'+args.save_name+'.pt')



